[5477976] Fix: issue removing Q/DQ nodes around custom ops with constant inputs #296
Conversation
Walkthrough
Adds defensive lookups and branching when rewiring Q/DQ edges to avoid KeyError, using `tensor_producers.get(...)` with a fallback to the original tensor name when no upstream producer exists (see the sketch after the diagram below).
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor Caller
    participant Utils as remove_input_dq_and_output_q
    participant Prev as PrevProducer (optional)
    participant Q as Q Node
    participant Consumer as Consumer
    participant DQ as DQ Node
    participant DownDQ as Downstream DQ (optional)
    Caller->>Utils: invoke on Q node
    note right of Utils: Input-side handling
    alt Prev exists
        Utils->>Consumer: set Consumer.input[cons_idx] = Prev.output[0]
    else Prev missing
        Utils->>Consumer: keep Consumer.input[cons_idx] = Q.input[0]
    end
    note right of Utils: Output-side handling
    alt Downstream DQ exists
        Utils->>DownDQ: set DownDQ.input[...] = Producer.output[0]
    else Downstream DQ missing
        Utils->>DQ: set DQ.input[0] = Producer.output[0]
    end
    Utils-->>Caller: return
```
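As a quick illustration of the branching above, here is a minimal Python sketch of the fallback logic. The function signature and the dict-based `tensor_producers` map are assumptions for illustration, not the exact modelopt implementation.

```python
# Illustrative sketch only: mirrors the walkthrough/diagram, not the real
# remove_input_dq_and_output_q signature in modelopt.
def rewire_qdq_edges(q_node, dq_node, consumer, cons_idx, producer, downstream_dq,
                     tensor_producers):
    # Input side: if the Q node's input has an upstream producer, bypass Q/DQ by
    # wiring the consumer to that producer's output; otherwise (constant or graph
    # input) keep the original tensor name.
    q_node_prev = tensor_producers.get(q_node.input[0], None)
    consumer.input[cons_idx] = (
        q_node_prev.output[0] if q_node_prev is not None else q_node.input[0]
    )

    # Output side: rewire the downstream DQ if it exists, otherwise fall back to
    # rewiring the local DQ node.
    if downstream_dq is not None:
        downstream_dq.input[0] = producer.output[0]
    else:
        dq_node.input[0] = producer.output[0]
```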
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Force-pushed: ef10a57 to e27a2b7
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/onnx/quantization/qdq_utils.py (1)
793-795: Good null-safe fallback; consider pruning the now-dead Q when deleting both Q/DQ.

Using `tensor_producers.get(...)` avoids KeyError and correctly falls back to the original tensor when there's no upstream producer. When you hit the "delete both" path, you likely leave an orphan Q node if it only fed this DQ. You can opportunistically mark that Q for removal.
Apply within-range change plus supporting setup:
```diff
-        q_node_prev = tensor_producers.get(q_node.input[0], None)
-        consumer.input[cons_idx] = q_node_prev.output[0] if q_node_prev else q_node.input[0]
+        q_node_prev = tensor_producers.get(q_node.input[0], None)
+        consumer.input[cons_idx] = q_node_prev.output[0] if q_node_prev else q_node.input[0]
+        # If Q had only this DQ consumer, schedule Q for deletion too
+        if len(tensor_consumers.get(q_node.output[0], [])) == 1:
+            q_indices.append(q_index_map[q_node.name])
```

Outside this hunk, initialize once near the top of the function:

```python
# Build fast lookup maps for later deletions
q_index_map = {n.name: idx for idx, n in q_nodes}
dq_index_map = {n.name: idx for idx, n in dq_nodes}
```
📜 Review details
📒 Files selected for processing (1)
- modelopt/onnx/quantization/qdq_utils.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
Force-pushed: dfd8f6a to 505f715
Note: Unit test generation is an Early Access feature. Expect some limitations and changes as we gather feedback and continue to improve it. Generating unit tests... This may take up to 20 minutes.
Here are the copyable unit test edits:

Copyable Edits — tests/unit/onnx/test_qdq_utils.py:

```diff
@@ -16,6 +16,7 @@
import numpy as np
import pytest
from onnx import TensorProto, helper, numpy_helper
+from onnx import checker
from modelopt.onnx.quantization.qdq_utils import _cast_fp4, _cast_fp8, quantize_weights_to_int4
@@ -333,3 +334,129 @@
assert result.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
assert result.shape == expected_array.shape
assert np.all(result == expected_array)
+
+# ---------------------------------------------------------------------------
+# Additional tests appended by CI to broaden coverage for qdq_utils
+# ---------------------------------------------------------------------------
+
+def create_minimal_model_dq_matmul_no_reshape_transpose():
+ """Create a model with DequantizeLinear feeding MatMul directly (no Reshape/Transpose)."""
+ # weight: int8, shape (K, N) = (8, 16)
+ w_data = np.random.randint(-8, 8, size=(8, 16), dtype=np.int8)
+ w = numpy_helper.from_array(w_data, "w_no_rt")
+ # scale per-row
+ s_data = np.random.uniform(0.1, 1.0, size=(8, 1)).astype(np.float32)
+ s = numpy_helper.from_array(s_data, "s_no_rt")
+
+ # input: (M, K) = (?, 8)
+ inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [None, 8])
+
+ dq = helper.make_node("DequantizeLinear", ["w_no_rt", "s_no_rt"], ["dq_w"], name="dq_no_rt")
+ mm = helper.make_node("MatMul", ["inp", "dq_w"], ["out"], name="mm_no_rt")
+
+ graph = helper.make_graph(
+ [dq, mm],
+ "g_no_rt",
+ [inp],
+ [helper.make_tensor_value_info("out", TensorProto.FLOAT, [None, 16])],
+ initializer=[w, s],
+ )
+ return helper.make_model(graph)
+
+
+class TestQuantizeWeightsToInt4_Additional:
+ def test_model_checker_valid_after_quantization(self):
+ """Quantized model should pass ONNX checker validation."""
+ model = create_test_model_with_dq_reshape_transpose_matmul()
+ qmodel = quantize_weights_to_int4(model)
+ # Validate structural correctness
+ checker.check_model(qmodel)
+
+ def test_idempotency_of_quantization(self):
+ """Applying quantization twice should be stable (no additional structural changes)."""
+ model = create_test_model_with_dq_reshape_transpose_matmul()
+ q1 = quantize_weights_to_int4(model)
+ q2 = quantize_weights_to_int4(q1)
+
+ # Compare key invariants: node op types multiset and initializer (name, dtype) pairs
+ ops1 = sorted([n.op_type for n in q1.graph.node])
+ ops2 = sorted([n.op_type for n in q2.graph.node])
+ assert ops1 == ops2
+
+ inits1 = sorted([(i.name, i.data_type) for i in q1.graph.initializer])
+ inits2 = sorted([(i.name, i.data_type) for i in q2.graph.initializer])
+ assert inits1 == inits2
+
+ # MatMul should still consume DequantizeLinear output directly
+ mm1 = next(n for n in q1.graph.node if n.op_type == "MatMul")
+ dq1 = next(n for n in q1.graph.node if n.op_type == "DequantizeLinear")
+ mm2 = next(n for n in q2.graph.node if n.op_type == "MatMul")
+ dq2 = next(n for n in q2.graph.node if n.op_type == "DequantizeLinear")
+ assert mm1.input[1] == dq1.output[0]
+ assert mm2.input[1] == dq2.output[0]
+
+ def test_no_pattern_present_is_handled_gracefully(self):
+ """When Reshape/Transpose pattern is absent, quantization should still succeed and keep graph valid."""
+ model = create_minimal_model_dq_matmul_no_reshape_transpose()
+ qmodel = quantize_weights_to_int4(model)
+
+ # Still valid ONNX
+ checker.check_model(qmodel)
+
+ # Weight initializer should be INT4 after quantization
+ w_init = next(i for i in qmodel.graph.initializer if i.name == "w_no_rt")
+ assert w_init.data_type == TensorProto.INT4
+
+ # Graph should still contain DequantizeLinear and MatMul; no Reshape/Transpose should appear
+ node_types = [n.op_type for n in qmodel.graph.node]
+ assert "DequantizeLinear" in node_types
+ assert "MatMul" in node_types
+ assert "Reshape" not in node_types
+ assert "Transpose" not in node_types
+
+
+class TestCastFunctions_Additional:
+ def test_cast_fp8_empty_and_specials(self):
+ """_cast_fp8 should handle empty arrays and special values without error and with correct dtype/shape."""
+ # Empty
+ arr_empty = np.array([], dtype=np.float32)
+ out_empty = _cast_fp8(arr_empty)
+ assert out_empty.dtype == np.dtype((np.uint8, [("e4m3fn", "u1")]))
+ assert out_empty.shape == (0,)
+
+ # Specials: NaN and Inf should not crash
+ arr_specials = np.array([np.nan, np.inf, -np.inf, 0.0, -0.0], dtype=np.float32)
+ out_specials = _cast_fp8(arr_specials)
+ assert out_specials.dtype == np.dtype((np.uint8, [("e4m3fn", "u1")]))
+ assert out_specials.shape == arr_specials.shape
+
+ def test_cast_fp8_random_bulk_shape_and_bounds(self):
+ """_cast_fp8 should preserve shape and produce uint8 payloads."""
+ arr = np.random.randn(7, 3, 5).astype(np.float32) * 10.0
+ out = _cast_fp8(arr)
+ assert out.shape == arr.shape
+ assert out.dtype == np.dtype((np.uint8, [("e4m3fn", "u1")]))
+ # Values are uint8-coded; verify range implicitly via dtype and via max/min checks on the raw view
+ raw = out.view(np.uint8)
+ assert raw.min() >= 0 and raw.max() <= 255
+
+ def test_cast_fp4_empty_and_specials(self):
+ """_cast_fp4 should handle empty arrays and special values without error and with correct dtype/shape."""
+ arr_empty = np.array([], dtype=np.float32)
+ out_empty = _cast_fp4(arr_empty)
+ assert out_empty.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
+ assert out_empty.shape == (0,)
+
+ arr_specials = np.array([np.nan, np.inf, -np.inf, 0.0, -0.0], dtype=np.float32)
+ out_specials = _cast_fp4(arr_specials)
+ assert out_specials.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
+ assert out_specials.shape == arr_specials.shape
+
+ def test_cast_fp4_random_bulk_shape_and_bounds(self):
+ """_cast_fp4 should preserve shape and produce uint8 payloads."""
+ arr = (np.random.rand(4, 4, 4).astype(np.float32) - 0.5) * 6.0
+ out = _cast_fp4(arr)
+ assert out.shape == arr.shape
+ assert out.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
+ raw = out.view(np.uint8)
+ assert raw.min() >= 0 and raw.max() <= 255
```

tests/unit/onnx/test_trt_utils.py: This is a new file.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/onnx/trt_utils.py (1)
294-297: Possible None dereference when appending to `intermediate_generated_files`.

`intermediate_generated_files` is typed as Optional, but the append is unconditional and will raise if it is None. Guard it.

```diff
-    intermediate_generated_files.append(static_shaped_onnx_path)  # type: ignore[union-attr]
+    if intermediate_generated_files is not None:
+        intermediate_generated_files.append(static_shaped_onnx_path)
```
🧹 Nitpick comments (1)
modelopt/onnx/trt_utils.py (1)
419-424: Mirror the quantize guard for casts and clarify the docstring

- Only add `custom_ops_to_cast[op_type]` when `inp_precision_cast` or `out_precision_cast` is non-empty, for symmetry with quantization (a sketch follows below).
- Update the `interpret_trt_plugins_precision_flag` docstring to state that op types with no quantizable indices are omitted from `custom_ops_to_quantize`.
- Consumers already guard map usage via membership checks or iterate keys; no direct indexing found.
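A minimal, self-contained sketch of the suggested symmetry. The variable names (`inp_precision_cast`, `out_precision_cast`, `custom_ops_to_cast`) follow this review comment and are assumptions about the surrounding scope in trt_utils, with example values filled in for illustration.

```python
# Sketch with example values; names are assumptions, not copied from trt_utils.
custom_ops_to_cast: dict[str, dict[str, list[int]]] = {}
op_type = "MyCustomOp"                # hypothetical op type
inp_precision_cast: list[int] = []    # e.g. input indices to cast to fp16
out_precision_cast: list[int] = [0]   # e.g. output indices to cast to fp16

# Only add the entry when there is something to cast, mirroring the quantize guard.
if inp_precision_cast or out_precision_cast:
    custom_ops_to_cast[op_type] = {
        "inp": inp_precision_cast,
        "out": out_precision_cast,
    }
```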
📜 Review details
📒 Files selected for processing (2)
- modelopt/onnx/quantization/qdq_utils.py (2 hunks)
- modelopt/onnx/trt_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/onnx/quantization/qdq_utils.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/onnx/trt_utils.py (1)
419-424: Good guard: avoid no-op entries in `custom_ops_to_quantize`.

Only adding the op when there is at least one int8/fp8 I/O index prevents empty keys and simplifies downstream handling. This aligns with the PR's intent around edge cases.
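To illustrate why the guard simplifies downstream handling, here is a hypothetical consumer of `custom_ops_to_quantize`; the dictionary content and the node being processed are placeholders, not the real call sites.

```python
# Hypothetical downstream usage: with the guard in place, every present key has at
# least one requested index, so a membership check is enough and no empty
# {"inp": [], "out": []} entries need special handling.
custom_ops_to_quantize = {"MyCustomOp": {"inp": [1], "out": []}}  # example content

node_op_type = "AnotherCustomOp"  # hypothetical node being processed
if node_op_type in custom_ops_to_quantize:
    indices = custom_ops_to_quantize[node_op_type]["inp"]
    print(f"quantize inputs {indices} of {node_op_type}")
else:
    print(f"{node_op_type} left at original precision")
```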
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #296      +/-   ##
==========================================
- Coverage   73.94%   73.93%   -0.01%
==========================================
  Files         172      172
  Lines       17405    17408       +3
==========================================
+ Hits        12870    12871       +1
- Misses       4535     4537       +2
```

☔ View full report in Codecov by Sentry.
LGTM. Please add/update any related tests.
Force-pushed: d8089c5 to 032e470
Signed-off-by: gcunhase <[email protected]>
Force-pushed: 032e470 to fabe5b8
Actionable comments posted: 0
🧹 Nitpick comments (5)
modelopt/onnx/trt_utils.py (5)
411-415: Mirror the guard for casts to avoid empty keys in `custom_ops_to_cast`.

Not functionally wrong, but staying consistent reduces downstream branching and noise.

Apply:

```diff
-        custom_ops_to_cast[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}
+        if inp_precision_cast or out_precision_cast:
+            custom_ops_to_cast[op_type] = {
+                "inp": inp_precision_cast,
+                "out": out_precision_cast,
+            }
```
376-383: Fix logging to show the originally requested precision.

`precision` is overwritten before logging, so the warning prints equal values.

Apply:

```diff
-        if precision in ["int8", "fp8"]:
-            if precision != quantize_mode:
-                precision = quantize_mode
-                logger.warning(
-                    f"Requested custom op precision ({precision}) is different than quantize mode: "
-                    f"{quantize_mode}. Mixed {precision}+{quantize_mode} precision is not yet supported. "
-                    f"Setting the custom op precision to be the same as quantize mode."
-                )
+        if precision in ["int8", "fp8"]:
+            if precision != quantize_mode:
+                requested_precision = precision
+                precision = quantize_mode
+                logger.warning(
+                    f"Requested custom op precision ({requested_precision}) is different than quantize mode: "
+                    f"{quantize_mode}. Mixed {requested_precision}+{quantize_mode} precision is not yet supported. "
+                    f"Setting the custom op precision to be the same as quantize mode."
+                )
```
388-394: Trim whitespace when parsing per-IO precisions.

Avoids false "unsupported" warnings for inputs like `[fp16, fp8 ]`.

Apply:

```diff
-        inp_precision, out_precision = precision.split(":")
-        inp_precision = inp_precision.strip("[]").split(",")
-        out_precision = out_precision.strip("[]").split(",")
+        raw_inp, raw_out = precision.split(":")
+        inp_precision = [p.strip() for p in raw_inp.strip("[]").split(",") if p.strip()]
+        out_precision = [p.strip() for p in raw_out.strip("[]").split(",") if p.strip()]
```
395-400: Prefer explicit exceptions over `assert` for user input validation.

Asserts can be stripped with optimization and yield generic errors.

Apply:

```diff
-        assert len(inp_precision) == num_inps, (
-            f"Number of inputs doesn't match expectation: {len(inp_precision)} vs {num_inps}."
-        )
-        assert len(out_precision) == num_outs, (
-            f"Number of outputs doesn't match expectation: {len(out_precision)} vs {num_outs}."
-        )
+        if len(inp_precision) != num_inps:
+            raise ValueError(
+                f"Number of inputs doesn't match expectation: {len(inp_precision)} vs {num_inps} for op '{op_type}'."
+            )
+        if len(out_precision) != num_outs:
+            raise ValueError(
+                f"Number of outputs doesn't match expectation: {len(out_precision)} vs {num_outs} for op '{op_type}'."
+            )
```
218-224: Avoid mutable default args.

`trt_plugins=[]` and `all_tensor_info={}` are mutable defaults.

Apply:

```diff
-def infer_types_shapes_tensorrt(
-    model: onnx.ModelProto,
-    trt_plugins: list[str] = [],
-    all_tensor_info: dict = {},
+def infer_types_shapes_tensorrt(
+    model: onnx.ModelProto,
+    trt_plugins: list[str] | None = None,
+    all_tensor_info: dict | None = None,
@@
-    if not all_tensor_info:
-        _, all_tensor_info = get_custom_layers(model, trt_plugins, strongly_typed)
+    if not all_tensor_info:
+        _, all_tensor_info = get_custom_layers(model, trt_plugins or [], strongly_typed)
```
📜 Review details
📒 Files selected for processing (2)
- modelopt/onnx/quantization/qdq_utils.py (2 hunks)
- modelopt/onnx/trt_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/onnx/quantization/qdq_utils.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: gpu-tests-pr
- GitHub Check: partial-install (torch)
- GitHub Check: multi-transformers (min)
- GitHub Check: multi-torch (26)
- GitHub Check: multi-torch (27)
- GitHub Check: windows
🔇 Additional comments (2)
modelopt/onnx/trt_utils.py (2)
419-423: Good guard: only add quantize entries when non-empty.

Prevents empty `{inp: [], out: []}` entries in `custom_ops_to_quantize` and aligns behavior with intent. Nice fix.
419-423: The scripts above will print the interpreter function and examine how `custom_ops_to_quantize` is used downstream. Once you provide the output, I'll confirm whether any direct indexing could raise KeyError or if all usages safely guard missing keys.
What does this PR do?
Type of change: Bug fix
Overview: Fixes an issue when quantizing constant inputs of custom ops. The function that removes the input-side DQ and output-side Q nodes assumed there was always a node before and after the Q/DQ pair when rewiring graph edges; this PR removes that assumption.
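To make the failure mode concrete, here is a hypothetical minimal graph of the kind described above: a custom op whose constant (initializer) input is wrapped in a Q/DQ pair, so the Q node's input has no producer node. All op and tensor names below are illustrative, not taken from a real model.

```python
import numpy as np
from onnx import TensorProto, helper, numpy_helper

# Constant weight stored as an initializer: it has no producer node in the graph.
const_w = numpy_helper.from_array(np.ones((4, 4), dtype=np.float32), "const_w")
scale = numpy_helper.from_array(np.array(0.1, dtype=np.float32), "scale")
zero_point = numpy_helper.from_array(np.array(0, dtype=np.int8), "zero_point")

# Q/DQ pair around the constant, feeding a custom op (op type and domain are made up).
q = helper.make_node("QuantizeLinear", ["const_w", "scale", "zero_point"], ["const_w_q"])
dq = helper.make_node("DequantizeLinear", ["const_w_q", "scale", "zero_point"], ["const_w_dq"])
custom = helper.make_node("MyCustomOp", ["x", "const_w_dq"], ["y"], domain="custom")

graph = helper.make_graph(
    [q, dq, custom],
    "constant_input_repro",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])],
    initializer=[const_w, scale, zero_point],
)
model = helper.make_model(graph)

# Removing the Q/DQ pair around "const_w" previously assumed a producer node for
# the Q input and hit a KeyError; the fix falls back to the tensor name itself.
```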
Usage
Can be used with any model with custom ops.
Testing
Before your PR is "Ready for review"
Additional Information
None
Summary by CodeRabbit
Bug Fixes
Chores